library(ggplot2)
library(Rcpp)
library(ggrepel)
library(patchwork)
library(dplyr)
library(tidyr)
library(lattice)
library(latticeExtra)

# Global ggplot2 theme: black-and-white base with slightly enlarged text.
theme_set(
  theme_bw() +
    theme(
      axis.title = element_text(size = 12),
      strip.text = element_text(size = 12),
      axis.text = element_text(size = 11),
      legend.text = element_text(size = 11)
    )
)

# Fix the RNG seed so the simulated confidence intervals are reproducible.
# The workspace is intentionally NOT cleared here: wiping the global
# environment would destroy the caller's objects whenever this script is
# sourced into an existing session. Run it in a fresh R session instead.
set.seed(123)

### load a few functions
# Compiles and loads the C++ helpers in src/misc.cpp
# (presumably the source of build_psi() used below -- confirm).
Rcpp::sourceCpp("src/misc.cpp")

### subdirectories
in_dir <- "./data"
outdir <- "./output"

### load data
# The outer parentheses echo the names of the objects restored by load().
(loaded <- load(file.path(in_dir, "analysis_data.RData")))
# Fitted models, indexed by model name further below.
est <- readRDS(file.path(outdir, "rum_estimates.rds"))

### constants
# ISO2 country codes, named by themselves so lapply() results carry the codes.
countries <- setNames(nm = c("CZ", "DE", "ES", "FR", "HU", "IT", "NL", "SE", "GB"))

# Display names keyed by ISO2 code; GB is relabelled, then sorted by label.
countries.lab <- setNames(
  countrycode::countrycode(countries, "iso2c", "country.name"),
  countries
)
countries.lab[["GB"]] <- "England"
countries.lab <- sort(countries.lab)

# Names of the three dimensions, in the order used throughout the script.
sdims <- c("econ", "cult", "populism")

###################################################
### plot dimension weights and complementarity ####
###################################################

## select the models for the display
# every model whose name has "4" as its second character, named by itself
touse <- grep("^.4", names(est), value = TRUE)   # for larger displays
names(touse) <- touse

# the "A4" models, keyed by the part of the name from the 4th character on
cntry.model <- grep("^A4", touse, value = TRUE)
names(cntry.model) <- substring(cntry.model, 4L)  # 1 model per country

## extract the loss matrix
# For each selected model: take the lambda coefficients, order them by the
# number embedded in their names, and hand them to build_psi() together with
# the number of dimensions (3 for models starting with "A", 2 otherwise).
loss_matrix <- lapply(cntry.model, function(m) {
  coefs <- est[[m]]@coef
  lam_names <- grep("^lambda", names(coefs), value = TRUE)
  lam_index <- as.integer(
    regmatches(lam_names, gregexpr("\\d+\\.?\\d*", lam_names))
  )
  lam <- coefs[lam_names[order(lam_index)]]
  n_dim <- if (grepl("^A", m)) 3 else 2
  build_psi(lam, n_dim)
})
  
# entries of a loss matrix and labs
# w_labs: position of each dimension on the diagonal.
w_labs <- setNames(c(1, 2, 3), c("w_econ", "w_cult", "w_pop"))
# temp holds the column-major linear index of every cell of a 3 x 3 matrix;
# cor_labs picks the indices of cells [2,1], [3,1] and [3,2].
temp <- matrix(1:9, nrow = 3, ncol = 3)
cor_labs <- setNames(
  temp[cbind(c(2, 3, 3), c(1, 1, 2))],
  c("cor_econ_cult", "cor_econ_pop", "cor_cult_pop")
)

# calculate the relevant quantities with confidence intervals
# For every model in `touse` this returns one row per quantity (dimension
# weights and complementarities) with the point estimate and a simulated
# 95% interval. Models whose simulation fails are skipped with a warning.
bulk <- lapply(touse, function(m) {
  e <- est[[m]]
  # models whose name starts with "A" are treated as 3-dimensional,
  # all others as 2-dimensional (econ and cult only)
  if (grepl("^A", m)) {
    nd <- 3
    w_labs_this <- w_labs
    cor_labs_this <- cor_labs
  } else {
    nd <- 2
    w_labs_this <- w_labs[c("w_econ", "w_cult")]
    # linear index 2 addresses cell [2,1] in a 2 x 2 matrix as well
    cor_labs_this <- cor_labs["cor_econ_cult"]
  }
  co <- e@coef
  vc <- e@vcov
  # lambda coefficients, ordered by the number embedded in their names
  text <- grep("^lambda", names(co), value = TRUE)
  sorter <- regmatches(text, gregexpr("\\d+\\.?\\d*", text)) |> as.integer()
  nam <- text[order(sorter)]
  # simulations: draw lambdas from their estimated sampling distribution
  sim <- tryCatch(
    MASS::mvrnorm(n = 1000, mu = co[nam], Sigma = vc[nam, nam]),
    error = function(err) {
      # report the dropped model instead of silently losing it from the output
      warning("Model '", m, "' skipped, simulation failed: ",
              conditionMessage(err), call. = FALSE)
      NULL
    }
  )
  if (is.null(sim)) return(NULL)
  # stack the point estimate on top of the simulated estimates
  sim <- rbind(co[nam], sim)
  # one column per parameter vector; rows are the weights followed by the
  # complementarities (assumes build_psi() returns an nd x nd matrix)
  ele <- apply(sim, 1L, function(r) {
    psi <- build_psi(r, nd)
    psi <- psi / psi[1L, 1L]       # scale so the first diagonal entry is 1
    w <- sqrt(diag(psi))           # dimension weights
    R <- diag(1 / w)
    psi.hat <- -R %*% psi %*% R    # negated, standardised loss matrix
    c(w, psi.hat[cor_labs_this])
  })
  estim <- ele[, 1L]                    # first column = point estimate
  draws <- ele[, -1L, drop = FALSE]     # remaining columns = simulated draws
  lb <- apply(draws, 1L, quantile, probs = 0.025)
  ub <- apply(draws, 1L, quantile, probs = 0.975)
  data.frame(cntry = substring(m, 4L, 5L),
             model = substring(m, 1L, 2L),
             quantity = c(names(w_labs_this), names(cor_labs_this)),
             est = estim, lb = lb, ub = ub)
}) |> bind_rows()

## importance
# Per country (model A4): the cultural and populist weights side by side,
# plus the ratio of the populist to the cultural weight.
bulk |>
  filter(model == "A4", quantity %in% c("w_cult", "w_pop")) |>
  select(cntry, quantity, est) |>
  pivot_wider(names_from = quantity, values_from = est) |>
  mutate(pop_to_cult = w_pop / w_cult)

# Figure 6: cultural and populist dimension weights with their simulated
# intervals, one panel per country, for the A4 and C4 models.
pic <- bulk |>
  filter(
    model %in% c("A4", "C4"),
    quantity %in% c("w_cult", "w_pop")
  ) |>
  mutate(
    # the first letter of the model name identifies its dimensionality
    dim = case_when(
      substring(model, 1L, 1L) == "A" ~ "3-dim",
      substring(model, 1L, 1L) == "C" ~ "2-dim"
    ),
    loc = paste0(
      ifelse(quantity == "w_cult", "Cultural", "Populist"),
      "\n(", dim, " model)"
    )
  ) |>
  ggplot(aes(y = loc)) +
  # reference lines at 0 and at 1
  geom_vline(xintercept = 0, color = "red") +
  geom_vline(xintercept = 1, color = "red") +
  geom_point(aes(x = est)) +
  geom_linerange(aes(xmin = lb, xmax = ub)) +
  facet_wrap(vars(cntry), labeller = labeller(cntry = countries.lab), nrow = 3) +
  scale_x_continuous(breaks = c(0, 0.5, 1, 2), trans = "log1p") +
  labs(
    x = "\nDimension weight\n(Square Root of Diagonal Entry of Loss Matrix)",
    y = element_blank(),
    shape = "Coordinates",
    linetype = "Coordinates"
  ) +
  theme(legend.position = "bottom")
ggsave(file.path(outdir, "figure6.pdf"), pic, height = 6, width = 7)

## complementarity
# Panel titles for the three pairwise complementarities.
qlabs <- c(
  "cor_econ_cult" = "Economic & Cultural",
  "cor_econ_pop" = "Economic & Populist",
  "cor_cult_pop" = "Cultural & Populist"
)
# Complementarities with their simulated intervals: countries in rows,
# dimension pairs in columns, A4 and C4 models on the y axis.
pic <- bulk |>
  filter(
    model %in% c("A4", "C4"),
    quantity %in% names(qlabs)
  ) |>
  mutate(
    dim = case_when(
      substring(model, 1L, 1L) == "A" ~ "3-dim\nmodel",
      substring(model, 1L, 1L) == "C" ~ "2-dim\nmodel"
    ),
    # plotted as is: the sign was already flipped when `bulk` was built
    nest = est,
    nlb = lb,
    nub = ub,
    qty.f = factor(quantity, levels = names(qlabs), labels = qlabs)
  ) |>
  ggplot(aes(y = dim)) +
  geom_vline(xintercept = -1, color = "gray") +
  geom_vline(xintercept = 0, color = "red") +
  geom_vline(xintercept = 1, color = "gray") +
  geom_point(aes(x = nest)) +
  geom_linerange(aes(xmin = nlb, xmax = nub)) +
  scale_x_continuous(breaks = c(-0.5, 0, 0.5)) +
  facet_grid(
    cntry ~ qty.f,
    scales = "free",
    labeller = labeller(cntry = countries.lab)
  ) +
  labs(
    x = "\nComplementarity\n(Negative of Off-Diagonal Entry of Standardized Loss Matrix)",
    y = element_blank()
  ) +
  theme(legend.position = "bottom")
ggsave(file.path(outdir, "complementarities.pdf"), pic, height = 9, width = 6.5)

###################################################
################ moving 1 unit ####################
###################################################

# One small table per dimension: a one-unit move up, no move, and a
# one-unit move down, each with its display label.
examples <- local({
  move_labs <- list(
    econ = c("more pro-market", "same economic position", "more pro-state"),
    cult = c("more cult. conservative", "same cultural position", "more cult. liberal"),
    pop = c("more populist", "same populist orientation", "more pluralist")
  )
  lapply(names(move_labs), function(d) {
    setNames(
      data.frame(move_labs[[d]], c(1, 0, -1)),
      c(paste0(d, "_lab"), d)
    )
  })
})
# Cartesian product of the three tables: all 27 combinations of moves.
grid <- Reduce(function(acc, nxt) merge(acc, nxt, by = NULL), examples)
# Utility of every combination of moves in `grid`, per country: minus the
# square root of the quadratic form of the move in the country's loss matrix
# (scaled so its first diagonal entry is 1). u_sca rescales the utilities to
# [0, 1] within each country.
bulk <- lapply(countries, function(cn) {
  psi <- loss_matrix[[cn]]
  psi_std <- psi / psi[1L, 1L]
  moves <- as.matrix(grid[, c("econ", "cult", "pop")])
  utility <- vapply(seq_len(nrow(moves)), function(i) {
    step <- moves[i, ]
    -sqrt(as.vector(t(step) %*% psi_std %*% step))
  }, numeric(1L))
  out <- data.frame(cntry = cn, grid, u = utility)
  u_range <- range(out$u)
  out$u_sca <- (out$u - u_range[1L]) / (u_range[2L] - u_range[1L])
  out
}) |>
  bind_rows() |>
  mutate(
    country = countries.lab[cntry],
    full_lab = paste0(pop_lab, " X ", econ_lab, " X ", cult_lab),
    text = formatC(u, digits = 1L, format = "f"),
    # squared length of the move in unweighted coordinates
    raw_sqdist = econ * econ + pop * pop + cult * cult
  )

# Figure 7: heat map of the utilities of moves that change exactly two
# dimensions by one unit (raw squared distance 2), grouped by the dimension
# that is held constant.
pic <- bulk |>
  filter(raw_sqdist == 2) |>
  mutate(
    # facet label: the dimension that does not move
    cons_lab = case_when(
      cult == 0 ~ "same cultural position",
      econ == 0 ~ "same economic position",
      pop == 0 ~ "same populist orientation"
    ),
    # row label: the two dimensions that do move
    lab = case_when(
      cult == 0 ~ paste0(pop_lab, " &\n", econ_lab, " than voter"),
      econ == 0 ~ paste0(pop_lab, " &\n", cult_lab, " than voter"),
      pop == 0 ~ paste0(econ_lab, " &\n", cult_lab, " than voter")
    )
  ) |>
  ggplot(aes(y = lab, x = country, fill = u_sca)) +
  geom_tile(color = "white") +
  geom_text(aes(label = text), color = "white", size = 4) +
  scale_fill_viridis_c(direction = -1) +
  facet_grid(cons_lab ~ ., scales = "free") +
  labs(x = element_blank(), y = element_blank()) +
  theme(
    legend.position = "none",
    axis.text.x = element_text(angle = 90, hjust = 1)
  )
ggsave(file.path(outdir, "figure7.pdf"), pic, width = 7, height = 7)


### directions of least/most resistance
# Eigen-decomposition of each country's loss matrix (scaled so its largest
# diagonal entry is 1). One row per eigenvector component; every eigenvector
# is oriented so that its third component is non-negative.
coll <- lapply(names(loss_matrix), function(cn) {
  psi <- loss_matrix[[cn]]
  decomp <- eigen(psi / max(diag(psi)), symmetric = TRUE)
  per_axis <- lapply(seq_along(decomp$values), function(k) {
    direction <- decomp$vectors[, k]
    if (direction[3L] < 0) {
      direction <- -direction
    }
    data.frame(cntry = cn, eiv = decomp$values[k], dim = sdims, val = direction)
  })
  bind_rows(per_axis)
}) |>
  bind_rows()

## the direction of least resistance (major axis)
# i.e. the eigenvector belonging to the smallest eigenvalue, per country
coll |>
  group_by(cntry, dim) |>
  slice_min(order_by = eiv) |>
  mutate(val = round(val, digits = 2)) |>
  pivot_wider(names_from = dim, values_from = val)

### make illustrations with the indifference curves in the German case
# axis titles, in the same order as `sdims`
dim.labs <- c("Pro-market\norientation", "Cultural\nconservatism", "Populist\norientation")

# the two parties being compared, named by themselves
p <- c("AfD", "CDU/CSU")
names(p) <- p

# German loss matrix scaled so its largest diagonal entry is 1, and the
# upper-triangular Cholesky factor U with M = t(U) %*% U
R <- loss_matrix[["DE"]]
M <- R / max(diag(R))
U <- chol(M)

## coordinates
# position of each party on (econ, cult, populism) as a plain numeric vector;
# `pdata` is not defined in this script (presumably restored by load() above)
qs <- lapply(p, function(which_party) {
  pdata |>
    filter(cntry == "DE" & party %in% which_party) |>
    select(econ, cult, populism) |>
    as.matrix() |>
    as.vector()
})
# hypothetical voter located halfway between the two parties
voter <- (qs[["AfD"]] + qs[["CDU/CSU"]]) / 2

examples <- pdata |>
  filter(cntry == "DE" & party %in% c("AfD", "CDU/CSU")) |>
  select(party, any_of(sdims))

# quadratic form of each party's offset from the voter in M
squt_li <- lapply(qs, function(q) {
  gap <- q - voter
  as.vector(t(gap) %*% M %*% gap)
})

# radius of the first party's indifference surface in transformed coordinates
rad <- sqrt(squt_li[[1L]])

## eigen decomposition
ei <- eigen(M, symmetric = TRUE)
D <- diag(ei$values)
S <- ei$vectors

### making a 3D plot
# points on a sphere of radius `rad`, on a 50 x 25 grid of angles
sph <- with(
  expand.grid(
    t = seq(0, 2 * pi, length.out = 50),
    k = seq(0, pi, length.out = 25)
  ),
  cbind(
    x = rad * sin(k) * cos(t),
    y = rad * sin(k) * sin(t),
    z = rad * cos(k)
  )
)

# map the sphere through the inverse Cholesky factor, which turns it into the
# ellipsoid of constant quadratic form in M, then centre it on the voter
npoints <- (-sph) %*% solve(t(U))
colnames(npoints) <- c("x", "y", "z")

npoints <- sweep(npoints, 2L, voter, FUN = "+")
npoints_df <- as.data.frame(npoints)
# one row per party with its x/y/z position; the party name goes to `labels`
parties_df <- bind_rows(
  lapply(qs, function(q) {
    as.data.frame(setNames(as.list(q), c("x", "y", "z")))
  }),
  .id = "labels"
)

# arguments that override the panel call when drawing the voter (green square)
dat_vot <- c(
  setNames(as.list(voter), c("x", "y", "z")),
  list(pch = 15, cex = 1, type = "p", col = "green")
)

# arguments that override the panel call when drawing the party names as text
dat_par <- c(
  as.list(parties_df),
  list(col = "black", cex = 1, panel.3d.cloud = "panel.3dtext")
)

# Panel function for cloud(): draws the regular panel first, then repeats the
# panel.cloud() call twice with part of the arguments replaced -- once by the
# entries of `dat_vot` (the voter) and once by those of `dat_par` (the
# parties). Both lists are read from the enclosing environment.
#   ...  arguments supplied by lattice for the panel
panel.cle <- function(...) {
  base_args <- list(...)
  voter_args <- replace(base_args, names(dat_vot), dat_vot)
  party_args <- replace(base_args, names(dat_par), dat_par)

  panel.cloud(...)
  do.call("panel.cloud", voter_args)
  do.call("panel.cloud", party_args)
}
 
# 3D view of the indifference surface (blue lines) with the voter and the two
# parties overlaid by the custom panel function.
plt3d <- cloud(npoints[,"z"] ~ npoints[,"x"] * npoints[,"y"],
      pch=16, col="blue", 
      groups=NULL, cex=0.1,  type = "l",
      lwd = 1,
      distance = 0, 
      scales=list(arrows=FALSE, cex=0.5, col="black", font=3, tck=0.6, distance=1),
      panel = panel.cle,
      xlab = list(dim.labs[1L], rot=25), 
      ylab = list(dim.labs[2L], rot=-30), 
      zlab = list(dim.labs[3L], rot=90),
      par.settings = list(axis.line = list(col = 'transparent'))
)
pdf(file.path(outdir, "figure8A.pdf"), width=4, height=4)
# Close the PDF device even when rendering fails; otherwise the device stays
# open and later plots are silently appended to a broken file.
invisible(tryCatch(print(plt3d), finally = dev.off()))

### flat projections
# Each entry lists the two plotted dimensions first and the dimension that is
# held fixed last.
axes.template <- list(
  c("econ", "cult", "populism"),
  c("econ", "populism", "cult"),
  c("cult", "populism", "econ")
)
# the same orderings expressed as positions within `sdims`
axes.perm <- lapply(axes.template, function(ax) match(ax, sdims))

# One 2-D figure per axis ordering: the first two dimensions of the ordering
# are plotted, the third is held at each party's own coordinate. For every
# party the curve drawn is the cross-section, at that party's third
# coordinate, of the surface on which the quadratic form in M around the
# voter equals the party's value in squt_li.
for (a in seq_along(axes.perm)) {
  # loss matrix with rows/columns reordered to (plotted 1, plotted 2, fixed)
  M_this <- M[axes.perm[[a]],axes.perm[[a]]]
  # partition: f0 = fixed-dimension entry, f1 = cross terms between the
  # plotted and the fixed dimension, om = block of the two plotted dimensions
  f0 <- M_this[3L,3L]
  f1 <- M_this[1:2,3L]
  om <- M_this[1:2, 1:2]
  iom <- solve(om)
  # inverse Cholesky factor of om; maps a circle onto an ellipse of constant
  # quadratic form in om
  iu <- solve(chol(om))
  # voter and party positions in the reordered coordinates
  pos.v <- voter[axes.perm[[a]]]
  pos.p <- lapply(qs, function(q) setNames(q[axes.perm[[a]]], c("x","y","z"))) |>
                    bind_rows(.id="party") |>
    mutate(
      # party label with its coordinate on the fixed dimension in parentheses
      plab = paste0(party, ' (',formatC(z, digits=1, format='f'),')') 
      )
  curves <- list()
  for (r in names(qs)) {
    pos <- qs[[r]][axes.perm[[a]]]
    # offset of the party from the voter on the fixed dimension
    z <- pos[3L] - pos.v[3L]
    names(z) <- NULL
    # Completing the square in the 2-D offset d from the voter:
    #   d' om d + 2 z f1' d + z^2 f0 = c
    # becomes (d + z iom f1)' om (d + z iom f1) = c - z^2 f0 + z^2 f1' iom f1,
    # so `rad` is the square root of the right-hand side (c = squt_li[[r]]).
    rad <- sqrt(squt_li[[r]] - z^2*f0 + z^2*(t(f1) %*% iom %*% f1))[1L,1L]
    # circle of radius `rad`, sampled at 100 angles
    points <- data.frame(t = seq(0,2*pi,length.out=100)) |>
      mutate(
        x = rad*cos(t),
        y = rad*sin(t)
      ) |> 
      select(x,y) |> 
      as.matrix()
    # turn the circle into the ellipse defined by om ...
    points_proj <- points %*% t(iu)
    # ... and move it to the centre of the cross-section (voter minus z iom f1)
    shift <- pos.v[1:2]-z*iom %*% f1
    npoints.df <- sweep(points_proj, 2L, shift, '+') |> as.data.frame()
    colnames(npoints.df) <- c('x','y')
    curves[[r]] <- npoints.df |>
      mutate(
        rad = rad,
        party = r,
        ori = a
      )
  }
  curves <- bind_rows(curves)
  # voter label: its coordinate on the fixed dimension
  vlab <- paste0("(",formatC(pos.v[3L], digits=1, format='f'),")")
  # curves (line type by party), voter as a green square, parties as labelled
  # dark blue points, red axes through the origin, fixed -4..4 window
  pic <- ggplot(data=pos.p) +
    geom_path(data=curves, 
                aes(x=x, y=y, group=party, linetype=party),
                color='darkblue') +
      annotate("point", x=pos.v[1L], y=pos.v[2L], color="green", shape=15, size=2) +
      annotate("text_repel", x=pos.v[1L], y=pos.v[2L], 
               label=vlab, color="green",  nudge_x=-0.1, nudge_y=0.1) +
      geom_point(aes(x=x, y=y), color='darkblue') + 
      geom_text_repel(aes(x=x, y=y, label=plab),
                      color='darkblue', nudge_x=-0.1, nudge_y=0.1)  + 
      expand_limits(x=0,y=0) +
      geom_hline(yintercept=0, color='red') +
      geom_vline(xintercept=0, color='red') +  
      coord_cartesian(xlim =c(-4,4), ylim = c(-4,4)) +
      scale_linetype_manual(breaks=c('AfD','CDU/CSU'), values=c("dashed","solid")) +
      labs(x=dim.labs[axes.perm[[a]][1L]], 
           y=dim.labs[axes.perm[[a]][2L]]) +
      theme(legend.position="none", plot.title = element_text(hjust = 0.5))
  # written as figure8B.pdf, figure8C.pdf, figure8D.pdf for a = 1, 2, 3
  ggsave(file.path(outdir, paste0("figure8",LETTERS[a+1L],".pdf")), pic, width=3, height=3)
} 










